
Aadduri/refactor emb #72

Open

abhinadduri wants to merge 8 commits into main from aadduri/refactor_emb

Conversation

@abhinadduri
Collaborator

Adds a balance_outliers argument for limiting outlier condition groups, such as control cells.
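A minimal sketch of how the new flag might be passed through the DataModule; the import path and config path below are placeholders, and the exact constructor signature may differ from this PR:

from cell_load import PerturbationDataModule  # placeholder import path

dm = PerturbationDataModule(
    toml_config_path="configs/example.toml",  # placeholder config
    batch_size=128,
    num_workers=8,
    balance_outliers=True,    # cap over-represented conditions (e.g. controls) during training
    collate_dtype="float32",  # opt back into full precision if float16 is too lossy
)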

abhinadduri and others added 8 commits February 23, 2026 08:12
…DataModule

- Add collate_dtype param to PerturbationDataset for float16/float32 tensor casting
- Wire collate_dtype through PerturbationDataModule to all dataset constructors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e_controls

When n_samples > pool_size (e.g., observational data with rare cell types
having only 2 control cells but sentence_len=64), the old tail+head wrap
only wrapped once, returning fewer elements than requested. This caused
IndexError in __getitems__ during multi-worker DataLoader training.

Use modular arithmetic to wrap around the pool as many times as needed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
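A small illustration of the modular wrap described in this commit (standalone sketch with placeholder names, not the PR's actual helper):

import numpy as np

def sample_with_wrap(pool: np.ndarray, start: int, n_samples: int) -> np.ndarray:
    # Take n_samples consecutive indices starting at `start`, cycling through
    # the pool as many times as needed instead of wrapping only once.
    idx = (start + np.arange(n_samples)) % len(pool)
    return pool[idx]

# With only 2 control cells but sentence_len=64, the pool is reused 32 times
# and the caller still gets the 64 elements it asked for.
controls = np.array([101, 202])
assert len(sample_with_wrap(controls, start=0, n_samples=64)) == 64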
…n conditions

Caps any condition exceeding the median sentence count using a rolling
window that advances each epoch, so all cells are eventually seen.
Applied only to training dataloaders; val/test remain unbalanced.

Bumps version to 0.11.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
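Roughly, the rolling window behaves like the sketch below (placeholder names; the sampler code in the PR may differ):

import numpy as np

def rolling_cap(sentence_idx: np.ndarray, cap: int, epoch: int) -> np.ndarray:
    # If a condition has more sentences than `cap` (the median across
    # conditions), keep a window of `cap` sentences whose start advances
    # with the epoch, so every sentence is eventually visited.
    n = len(sentence_idx)
    if n <= cap:
        return sentence_idx
    start = (epoch * cap) % n
    window = (start + np.arange(cap)) % n
    return sentence_idx[window]

# A condition with 10 sentences and a cap of 4 cycles through different
# subsets as epochs advance.
idx = np.arange(10)
print(rolling_cap(idx, cap=4, epoch=0))  # [0 1 2 3]
print(rolling_cap(idx, cap=4, epoch=1))  # [4 5 6 7]
print(rolling_cap(idx, cap=4, epoch=2))  # [8 9 0 1]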
Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request updates the version to 0.11.0 and transitions the project from print statements to a structured logging framework. It introduces several performance and memory optimizations, including a collate_dtype parameter to cast tensors to lower precision and a configurable pin_memory option for the dataloader. Additionally, a balance_outliers feature was added to the sampler to downsample over-represented perturbations using a rolling window. Feedback highlights potential issues with the new default values for collate_dtype and pin_memory, and suggests improvements for logging efficiency and code redundancy.

use_consecutive_loading: bool = False,
h5_open_kwargs: dict | None = None,
show_progress: bool = True,
collate_dtype: str = "float16",
Contributor
high

The default value for collate_dtype is set to "float16". This is a significant change from the previous implicit default of float32 (via torch.FloatTensor). While this reduces memory usage, it may cause precision issues or compatibility problems with models that expect float32 inputs. Consider whether "float32" would be a safer default for a general-purpose dataloader, or ensure this change is clearly documented.
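For illustration, a quick sketch of why the lower-precision default can surprise callers; the cast function is a conceptual stand-in, not the actual PerturbationDataset code:

import torch

_DTYPES = {"float16": torch.float16, "float32": torch.float32}

def cast_batch(batch: torch.Tensor, collate_dtype: str = "float32") -> torch.Tensor:
    # Stand-in for the collate-time cast.
    return batch.to(_DTYPES[collate_dtype])

x = torch.tensor([1e5, 1e-8])
print(cast_batch(x, "float32"))  # values preserved
print(cast_batch(x, "float16"))  # tensor([inf, 0.], ...) -- float16 overflows above ~65504 and underflows tiny values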

toml_config_path: str,
batch_size: int = 128,
num_workers: int = 8,
pin_memory: bool = False,
Contributor
medium

The default for pin_memory has been changed from True (hardcoded in the previous version) to False. Memory pinning is generally recommended when training on GPUs as it speeds up data transfer from CPU to GPU. If the primary use case is GPU training, consider keeping the default as True.
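For context, the usual GPU-training pattern pins host memory so host-to-device copies can be asynchronous (generic PyTorch sketch, not code from this PR):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(1024, 32))
loader = DataLoader(ds, batch_size=128, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
for (batch,) in loader:
    # With pinned memory, non_blocking=True lets the copy overlap with compute.
    batch = batch.to(device, non_blocking=True)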

Comment on lines +261 to +265
        logger.info(
            f"balance_outliers: cap={cap} (median). "
            f"{n_capped}/{len(unique_codes)} conditions capped. "
            f"Sentences: {total_before} -> {total_after} (epoch {epoch})."
        )
Contributor
medium

It is recommended to use lazy interpolation in logging calls (passing arguments to the logger) rather than f-strings. This avoids the overhead of string formatting if the log level is disabled and maintains consistency with the logging style used in other parts of the PR (e.g., in filter_on_target_knockdown.py).

        logger.info(
            "balance_outliers: cap=%d (median). %d/%d conditions capped. Sentences: %d -> %d (epoch %d).",
            cap,
            n_capped,
            len(unique_codes),
            total_before,
            total_after,
            epoch,
        )

Comment on lines +103 to +105
        self.output_space = _OUTPUT_SPACE_ALIASES.get(
            kwargs.get("output_space", "gene"), kwargs.get("output_space", "gene")
        )
Contributor
medium

The call to kwargs.get("output_space", "gene") is redundant as it is executed twice. Extracting it into a variable makes the code cleaner.

        output_space = kwargs.get("output_space", "gene")
        self.output_space = _OUTPUT_SPACE_ALIASES.get(output_space, output_space)
